package com.cat2bug.junit.clazz;

import com.cat2bug.junit.Cat2BugSpringAutoRunner;
import com.cat2bug.junit.annotation.Authentication;
import com.cat2bug.junit.annotation.Copy;
import com.cat2bug.junit.service.FunctionTestClassReportService;
import com.cat2bug.junit.service.ParameterService;
import com.cat2bug.junit.util.ParamMethodUtil;
import com.cat2bug.junit.vo.HttpInterfaceVo;
import javassist.*;
import javassist.bytecode.CodeAttribute;
import javassist.bytecode.LocalVariableAttribute;
import javassist.bytecode.MethodInfo;
import org.apache.commons.logging.Log;
import org.junit.Test;
import org.junit.platform.commons.util.StringUtils;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.test.context.web.WebAppConfiguration;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.context.WebApplicationContext;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.*;

public class SpringControllerTestClassFactory {
	/** Spring MVC mapping annotations that mark a controller method as an HTTP endpoint. */
	private static final List<Class<? extends Annotation>> MAPPING_ANNOTATIONS = Arrays.asList(GetMapping.class,
			PostMapping.class, PutMapping.class, DeleteMapping.class, RequestMapping.class, PatchMapping.class);

	/** Javassist class pool used to read the bytecode of controller and test-case classes. */
	ClassPool pool = ClassPool.getDefault();

	/**
	 * Dynamically generates a JUnit test class for a Spring controller.
	 * <p>
	 * The generated class is named {@code <ControllerSimpleName>Test} and lives in the controller's
	 * package. It receives:
	 * <ul>
	 * <li>{@code @RunWith(Cat2BugSpringAutoRunner.class)}, the non-{@code @RunWith} annotations of
	 * {@code testCaseClass}, and {@code @SpringBootTest} / {@code @WebAppConfiguration} when missing;</li>
	 * <li>{@code log}, {@code webContext} and {@code mock} fields plus the code that initialises them;</li>
	 * <li>an optional Spring Security login method driven by {@link Authentication};</li>
	 * <li>one parameter-generating method per (non-filtered) controller parameter;</li>
	 * <li>test methods for the controller's mapped endpoints;</li>
	 * <li>copies of the {@link Copy}-annotated methods and fields of {@code testCaseClass}.</li>
	 * </ul>
	 * The bytecode is also written to {@code ./target/cat2bug-junit/classes} (best effort).
	 *
	 * @param testCaseClass user test-case class whose annotations and {@code @Copy} members are reused
	 * @param clazz         controller class under test
	 * @return the generated test class, loaded via {@link CtClass#toClass()}
	 * @throws Exception if reflection, bytecode generation or class loading fails
	 */
	public Class<?> createTestClass(Class<?> testCaseClass, Class<?> clazz) throws Exception {
		String proxyClassName = clazz.getSimpleName() + "Test"; // simple name of the generated test class
		String packageName = clazz.getPackage().getName(); // package of the generated test class
		String longProxyClassName = packageName + "." + proxyClassName; // fully qualified name of the test class
		Annotation[] anns = testCaseClass.getAnnotations();

		// The decorators are chained in the same order as before; the factory chain may depend on it.
		ITestClassFactory factory = new TestClassFactory(proxyClassName, packageName);
		factory = this.addClassAnnotations(factory, testCaseClass, anns);
		factory = this.addMockMvcSupport(factory, longProxyClassName);
		factory = this.addAuthentication(factory, anns);

		Set<Method> methods = scanControllerMethod(clazz); // endpoint methods of the controller under test
		factory = this.addArgumentMethods(factory, clazz, longProxyClassName, methods);
		factory = this.addTestMethod(factory, methods);
		// Copy the @Copy-annotated methods and fields of the test case into the generated class.
		factory = this.copyMethod(factory, testCaseClass);
		factory = this.copyField(factory, testCaseClass);

		CtClass ctClass = factory.createTestClass(clazz);
		this.writeFile(ctClass);
		return ctClass.toClass();
	}

	/**
	 * Adds the class-level annotations of the generated test class.
	 *
	 * @param factory       factory chain to decorate
	 * @param testCaseClass user test-case class
	 * @param anns          annotations present on {@code testCaseClass}
	 * @return the decorated factory
	 * @throws Exception if an annotation attribute cannot be read reflectively
	 */
	private ITestClassFactory addClassAnnotations(ITestClassFactory factory, Class<?> testCaseClass,
			Annotation[] anns) throws Exception {
		Map<String, Object> runWithParams = new HashMap<>();
		runWithParams.put("value", Cat2BugSpringAutoRunner.class);
		factory = new AddAnnotationOfTestClass(factory, RunWith.class, runWithParams);

		// Re-apply the test case's own annotations (except @RunWith, which is replaced above),
		// copying every attribute value.
		for (Annotation ann : anns) {
			if (ann instanceof RunWith) {
				continue;
			}
			Map<String, Object> annParams = new HashMap<>();
			for (Method annMethod : ann.annotationType().getDeclaredMethods()) {
				annParams.put(annMethod.getName(), annMethod.invoke(ann));
			}
			factory = new AddAnnotationOfTestClass(factory, ann.annotationType(), annParams);
		}

		SpringBootTest springBootTest = testCaseClass.getAnnotation(SpringBootTest.class);
		if (springBootTest == null) {
			factory = new AddAnnotationOfTestClass(factory, SpringBootTest.class);
		}
		// @WebAppConfiguration is only needed for WebEnvironment.MOCK. A @SpringBootTest added above
		// carries default attributes, and its default webEnvironment is MOCK (previously this NPE'd).
		SpringBootTest.WebEnvironment webEnvironment = springBootTest == null
				? SpringBootTest.WebEnvironment.MOCK
				: springBootTest.webEnvironment();
		if (testCaseClass.getAnnotation(WebAppConfiguration.class) == null
				&& webEnvironment == SpringBootTest.WebEnvironment.MOCK) {
			factory = new AddAnnotationOfTestClass(factory, WebAppConfiguration.class);
		}
		return factory;
	}

	/**
	 * Adds the {@code log}, {@code @Autowired webContext} and {@code mock} fields, a constructor that
	 * initialises {@code log}, and an {@code @Autowired before()} method that builds {@code mock}
	 * from {@code webContext}.
	 *
	 * @param factory            factory chain to decorate
	 * @param longProxyClassName fully qualified name of the generated class (used as the log name)
	 * @return the decorated factory
	 * @throws Exception if a decorator cannot be created
	 */
	private ITestClassFactory addMockMvcSupport(ITestClassFactory factory, String longProxyClassName)
			throws Exception {
		factory = new AddFieldOfTestClass(factory, Log.class, "log");
		Map<Class<? extends Annotation>, Map<String, Object>> webContextAnnotationParams = new HashMap<>();
		webContextAnnotationParams.put(Autowired.class, null);
		factory = new AddFieldOfTestClass(factory, WebApplicationContext.class, "webContext",
				webContextAnnotationParams);
		factory = new AddFieldOfTestClass(factory, MockMvc.class, "mock");

		factory = new AbstractAddConstructorOfTestClass(factory) {
			@Override
			public String body() {
				return "{ this.log=org.apache.commons.logging.LogFactory.getLog(\"" + longProxyClassName + "\"); }";
			}
		};

		// Marked @Autowired so Spring invokes it during injection, initialising MockMvc before the tests run.
		Map<Class<? extends Annotation>, Map<String, Object>> beforeAnnotationParams = new HashMap<>();
		beforeAnnotationParams.put(Autowired.class, null);
		factory = new AbstractAddMethodOfTestClass(factory, "before", null, beforeAnnotationParams) {
			@Override
			public String body(CtClass ctClass) {
				return "{mock = org.springframework.test.web.servlet.setup.MockMvcBuilders.webAppContextSetup(webContext).build();}";
			}
		};
		return factory;
	}

	/**
	 * When the test case carries {@link Authentication} with a non-blank name and password, adds an
	 * autowired {@code authenticationManager} field and an {@code @Autowired cat2bugAuthentication()}
	 * method. That method authenticates the credentials and stores the result in
	 * {@code SecurityContextHolder}, logging any failure.
	 *
	 * @param factory factory chain to decorate
	 * @param anns    annotations present on the test-case class
	 * @return the decorated factory (unchanged when no usable {@code @Authentication} is present)
	 * @throws Exception if a decorator cannot be created
	 */
	private ITestClassFactory addAuthentication(ITestClassFactory factory, Annotation[] anns) throws Exception {
		Optional<Authentication> authentication = Arrays.stream(anns)
				.filter(Authentication.class::isInstance)
				.map(Authentication.class::cast)
				.findFirst();
		if (!authentication.isPresent()) {
			return factory;
		}
		Authentication auth = authentication.get();
		if (!(StringUtils.isNotBlank(auth.name()) && StringUtils.isNotBlank(auth.password()))) {
			return factory;
		}
		// The credentials are spliced into Java source compiled by Javassist; escape them so that
		// quotes or backslashes cannot break (or inject code into) the generated method.
		String name = escapeJavaString(auth.name());
		String password = escapeJavaString(auth.password());

		Map<Class<? extends Annotation>, Map<String, Object>> authenticationManagerAnnotationParams = new HashMap<>();
		authenticationManagerAnnotationParams.put(Autowired.class, null);
		factory = new AddFieldOfTestClass(factory, AuthenticationManager.class, "authenticationManager",
				authenticationManagerAnnotationParams);

		Map<Class<? extends Annotation>, Map<String, Object>> securityAnnotationParams = new HashMap<>();
		securityAnnotationParams.put(Autowired.class, null);
		factory = new AbstractAddMethodOfTestClass(factory, "cat2bugAuthentication", null, securityAnnotationParams) {
			@Override
			public String body(CtClass ctClass) {
				StringBuilder sb = new StringBuilder();
				sb.append("{");
				sb.append("try {");
				sb.append(String.format("org.springframework.security.authentication.UsernamePasswordAuthenticationToken authenticationToken = new org.springframework.security.authentication.UsernamePasswordAuthenticationToken(\"%s\", \"%s\");",
						name, password));
				sb.append("org.springframework.security.core.context.SecurityContext securityContext = org.springframework.security.core.context.SecurityContextHolder.createEmptyContext();");
				sb.append("securityContext.setAuthentication(authenticationManager.authenticate(authenticationToken));");
				sb.append("org.springframework.security.core.context.SecurityContextHolder.setContext(securityContext);");
				sb.append("} catch (Exception e) {");
				sb.append("log.error(e);");
				sb.append("}");
				sb.append("}");
				return sb.toString();
			}
		};
		return factory;
	}

	/**
	 * Adds one parameter-generating method ({@code AddArgeMethodOfTestClass}) for each parameter of
	 * each endpoint method, skipping parameter types that {@code ParameterService} filters out.
	 * Parameter names are read from the bytecode's local variable table.
	 *
	 * @param factory            factory chain to decorate
	 * @param clazz              controller class under test
	 * @param longProxyClassName fully qualified name of the generated class
	 * @param methods            endpoint methods of the controller
	 * @return the decorated factory
	 * @throws Exception if bytecode cannot be read or a decorator cannot be created
	 * @throws IllegalStateException if a parameter name is not recorded in the bytecode
	 */
	private ITestClassFactory addArgumentMethods(ITestClassFactory factory, Class<?> clazz,
			String longProxyClassName, Set<Method> methods) throws Exception {
		for (Method m : methods) {
			CtClass srcClass = pool.getCtClass(m.getDeclaringClass().getName()); // bytecode of the declaring class
			CtMethod srcMethod = findCtMethod(srcClass, m); // exact overload, not just the first with this name
			CodeAttribute codeAttribute = srcMethod.getMethodInfo().getCodeAttribute();
			if (codeAttribute == null) {
				continue; // no Code attribute (e.g. abstract/native): no local variable table to read
			}
			LocalVariableAttribute attr = (LocalVariableAttribute) codeAttribute
					.getAttribute(LocalVariableAttribute.tag);
			CtClass[] paramTypes = srcMethod.getParameterTypes();
			Object[][] paramAnnotations = srcMethod.getParameterAnnotations();
			// Slot 0 holds `this` for instance methods; long and double parameters take two slots each.
			int slot = Modifier.isStatic(srcMethod.getModifiers()) ? 0 : 1;
			for (int i = 0; i < paramTypes.length; i++) {
				CtClass paramType = paramTypes[i];
				int paramSlot = slot;
				slot += isWideType(paramType) ? 2 : 1; // advance even for filtered parameters
				if (ParameterService.getInstance().isFilter(paramType.getName())) {
					continue;
				}
				String paramName = localVariableName(attr, paramSlot, m);
				String methodName = ParamMethodUtil.createMethodName(m.getName(), paramName, paramType.getName());
				factory = new AddArgeMethodOfTestClass(factory, methodName, clazz, longProxyClassName, m, paramName,
						paramType, paramAnnotations[i]);
			}
		}
		return factory;
	}

	/**
	 * Adds test methods for the endpoint methods.
	 * <p>
	 * {@code GetMapping}/{@code PostMapping}/{@code PutMapping}/{@code DeleteMapping} map to the
	 * matching HTTP factory. For {@code RequestMapping}, one factory is added per declared GET, POST,
	 * PUT or DELETE method, and GET is used when no method is declared. PATCH, HEAD, OPTIONS and TRACE
	 * are not supported yet and are skipped.
	 *
	 * @param factory factory chain to decorate
	 * @param methods endpoint methods to test
	 * @return the decorated factory
	 */
	private ITestClassFactory addTestMethod(ITestClassFactory factory, Set<Method> methods) {
		Map<Class<? extends Annotation>, Map<String, Object>> testMethodAnnotationParams = new HashMap<>();
		testMethodAnnotationParams.put(Test.class, null);
		for (Method m : methods) {
			// Locale.ROOT keeps the generated name independent of the JVM default locale.
			// NOTE(review): overloaded endpoints, or a RequestMapping declaring several methods, produce
			// the same test method name more than once — confirm the factories tolerate that.
			String testMethodName = "test" + m.getName().substring(0, 1).toUpperCase(Locale.ROOT)
					+ m.getName().substring(1);
			if (m.getAnnotation(GetMapping.class) != null) {
				factory = new AddHttpGetOfTestMethod(factory, testMethodName, m, null, testMethodAnnotationParams);
			} else if (m.getAnnotation(PostMapping.class) != null) {
				factory = new AddHttpPostOfTestMethod(factory, testMethodName, m, null, testMethodAnnotationParams);
			} else if (m.getAnnotation(PutMapping.class) != null) {
				factory = new AddHttpPutOfTestMethod(factory, testMethodName, m, null, testMethodAnnotationParams);
			} else if (m.getAnnotation(DeleteMapping.class) != null) {
				factory = new AddHttpDeleteOfTestMethod(factory, testMethodName, m, null, testMethodAnnotationParams);
			} else if (m.getAnnotation(PatchMapping.class) != null) {
				// PATCH endpoints are not supported yet.
			} else if (m.getAnnotation(RequestMapping.class) != null) {
				RequestMapping rms = m.getAnnotation(RequestMapping.class);
				if (rms.method().length > 0) {
					for (RequestMethod rm : rms.method()) {
						switch (rm) {
							case GET:
								factory = new AddHttpGetOfTestMethod(factory, testMethodName, m, null,
										testMethodAnnotationParams);
								break;
							case POST:
								factory = new AddHttpPostOfTestMethod(factory, testMethodName, m, null,
										testMethodAnnotationParams);
								break;
							case PUT:
								factory = new AddHttpPutOfTestMethod(factory, testMethodName, m, null,
										testMethodAnnotationParams);
								break;
							case DELETE:
								factory = new AddHttpDeleteOfTestMethod(factory, testMethodName, m, null,
										testMethodAnnotationParams);
								break;
							default:
								// HEAD, PATCH, OPTIONS and TRACE are not supported yet.
								break;
						}
					}
				} else {
					factory = new AddHttpGetOfTestMethod(factory, testMethodName, m, null, testMethodAnnotationParams);
				}
			}
		}
		return factory;
	}

	/**
	 * Copies the public methods of the test case annotated with {@link Copy} into the test class.
	 *
	 * @param factory       class factory
	 * @param testCaseClass test-case class
	 * @return the decorated factory
	 * @throws NotFoundException if the bytecode of a declaring class or method cannot be found
	 */
	private ITestClassFactory copyMethod(ITestClassFactory factory, Class<?> testCaseClass) throws NotFoundException {
		for (Method cm : testCaseClass.getMethods()) {
			if (cm.getAnnotation(Copy.class) != null) {
				CtClass srcClass = pool.getCtClass(cm.getDeclaringClass().getName()); // bytecode of the declaring class
				CtMethod srcMethod = findCtMethod(srcClass, cm); // exact overload, not just the first with this name
				factory = new CopyMethodOfTestClass(factory, srcMethod);
			}
		}
		return factory;
	}

	/**
	 * Copies the fields declared by the test case and annotated with {@link Copy} into the test class.
	 *
	 * @param factory       class factory
	 * @param testCaseClass test-case class
	 * @return the decorated factory
	 * @throws NotFoundException      if the bytecode of the test-case class cannot be found
	 * @throws ClassNotFoundException if a field annotation class cannot be loaded
	 */
	private ITestClassFactory copyField(ITestClassFactory factory, Class<?> testCaseClass) throws NotFoundException, ClassNotFoundException {
		CtClass srcClass = pool.getCtClass(testCaseClass.getName()); // bytecode of the test-case class
		for (CtField f : srcClass.getDeclaredFields()) {
			if (f.getAnnotation(Copy.class) != null) {
				factory = new CopyFiledOfTestClass(factory, testCaseClass, f);
			}
		}
		return factory;
	}

	/**
	 * Writes the generated class file under {@code ./target/cat2bug-junit/classes}.
	 * This is best effort: a failure is printed and otherwise ignored.
	 *
	 * @param ctClass generated class
	 */
	private void writeFile(CtClass ctClass) {
		try {
			ctClass.writeFile("./target/cat2bug-junit/classes");
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Scans the public methods of a controller for Spring MVC mapping annotations.
	 *
	 * @param testClass controller class to scan
	 * @return methods carrying at least one of {@link #MAPPING_ANNOTATIONS}
	 */
	private Set<Method> scanControllerMethod(Class<?> testClass) {
		Set<Method> ret = new HashSet<>();
		for (Method m : testClass.getMethods()) {
			if (MAPPING_ANNOTATIONS.stream().anyMatch(m::isAnnotationPresent)) {
				ret.add(m);
			}
		}
		return ret;
	}

	/**
	 * Finds the Javassist method with the same name and parameter types as {@code method}.
	 * Overloads are distinguished, and the code falls back to the first method with that name
	 * when no exact match exists.
	 *
	 * @param srcClass bytecode of the declaring class
	 * @param method   reflective method to match
	 * @return the matching {@link CtMethod}
	 * @throws NotFoundException if no method with that name exists or a parameter type cannot be resolved
	 */
	private static CtMethod findCtMethod(CtClass srcClass, Method method) throws NotFoundException {
		Class<?>[] reflectTypes = method.getParameterTypes();
		for (CtMethod candidate : srcClass.getDeclaredMethods()) {
			if (!candidate.getName().equals(method.getName())) {
				continue;
			}
			CtClass[] ctTypes = candidate.getParameterTypes();
			if (ctTypes.length != reflectTypes.length) {
				continue;
			}
			boolean sameSignature = true;
			for (int i = 0; i < ctTypes.length && sameSignature; i++) {
				// Both use source-style names, e.g. "int", "java.lang.String[]", "a.Outer$Inner".
				sameSignature = ctTypes[i].getName().equals(reflectTypes[i].getTypeName());
			}
			if (sameSignature) {
				return candidate;
			}
		}
		return srcClass.getDeclaredMethod(method.getName());
	}

	/**
	 * Returns whether a type occupies two local variable slots (long and double).
	 *
	 * @param type parameter type
	 * @return {@code true} for {@code long} and {@code double}
	 */
	private static boolean isWideType(CtClass type) {
		return type == CtClass.longType || type == CtClass.doubleType;
	}

	/**
	 * Looks up the name of the parameter stored in the given local variable slot.
	 * <p>
	 * Entries are matched by slot index, not by their position in the table. An entry whose scope
	 * starts at pc 0 (a parameter) is preferred.
	 *
	 * @param attr   local variable table, or {@code null} when the class has no debug info
	 * @param slot   local variable slot of the parameter
	 * @param method method being inspected (for the error message)
	 * @return the parameter name
	 * @throws IllegalStateException if the name is not recorded in the bytecode
	 */
	private static String localVariableName(LocalVariableAttribute attr, int slot, Method method) {
		if (attr != null) {
			String fallback = null;
			for (int i = 0; i < attr.tableLength(); i++) {
				if (attr.index(i) != slot) {
					continue;
				}
				if (attr.startPc(i) == 0) {
					return attr.variableName(i);
				}
				if (fallback == null) {
					fallback = attr.variableName(i);
				}
			}
			if (fallback != null) {
				return fallback;
			}
		}
		throw new IllegalStateException("Cannot resolve the name of the parameter in local variable slot " + slot
				+ " of " + method + "; compile the controller with debug information (javac -g).");
	}

	/**
	 * Escapes a value for use inside a double-quoted Java string literal in Javassist source.
	 *
	 * @param value raw value
	 * @return the escaped value, without surrounding quotes
	 */
	private static String escapeJavaString(String value) {
		StringBuilder sb = new StringBuilder(value.length());
		for (char c : value.toCharArray()) {
			switch (c) {
				case '\\':
					sb.append("\\\\");
					break;
				case '"':
					sb.append("\\\"");
					break;
				case '\n':
					sb.append("\\n");
					break;
				case '\r':
					sb.append("\\r");
					break;
				case '\t':
					sb.append("\\t");
					break;
				default:
					sb.append(c);
			}
		}
		return sb.toString();
	}
}
